#@title Self-guidance modules
from collections import defaultdict
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from PIL import Image
import os
from functools import partial

from utils.gaussian_smoothing import GaussianSmoothing
from utils.utils import search_sequence_numpy

from globals import *


def resave_aux_key(module, *args, old_key='attn', new_key='last_attn'):
    """Forward hook that aliases ``module._aux[old_key]`` under ``new_key``.

    The positional hook arguments (inputs / output) are ignored; only the
    module's existing ``_aux`` dict is read and updated. The same object is
    stored under both keys (no copy). Raises ``AttributeError`` if the module
    has no ``_aux`` and ``KeyError`` if ``old_key`` is missing.
    """
    aux_store = module._aux
    aux_store[new_key] = aux_store[old_key]

def _select_stash_value(args, kwargs, output, mode, args_idx=None, kwargs_key=None, fn_to_run=None):
    """Return the value a forward hook should record, chosen by ``mode``.

    Args:
        args: positional inputs of the hooked ``forward`` call.
        kwargs: keyword inputs of the hooked ``forward`` call.
        output: value returned by the hooked ``forward``.
        mode: ``"args"``, ``"kwargs"`` or ``"output"``; any other value selects ``None``.
        args_idx: for ``mode="args"``, index into ``args``; ``None`` selects the whole tuple.
        kwargs_key: for ``mode="kwargs"``, the key to read from ``kwargs`` (required).
        fn_to_run: optional transform applied to the selected value.

    Returns:
        The selected value, after ``fn_to_run`` if one is given.

    Raises:
        ValueError: ``mode="kwargs"`` without ``kwargs_key``.
    """
    if mode == "args":
        # Whole positional-input tuple unless a specific index is requested.
        to_save = args if args_idx is None else args[args_idx]
    elif mode == "kwargs":
        if kwargs_key is None:
            raise ValueError("kwargs_key is required when mode='kwargs'")
        to_save = kwargs[kwargs_key]
    elif mode == "output":
        # Tuple outputs are reduced to their first element.
        to_save = output[0] if isinstance(output, tuple) else output
    else:
        to_save = None
    if fn_to_run is not None:
        to_save = fn_to_run(to_save)
    return to_save


def stash_to_aux(module, args, kwargs, output, mode, key="last_feats", args_idx=None, kwargs_key=None, fn_to_run=None, last=False):
    """Forward hook (register with ``with_kwargs=True``) recording a value in ``module._aux[key]``.

    ``module._aux[key]`` is a list with one slot per recorded timestep. Two
    ``Globals`` flags decide how it is updated:

    * ``Globals.get('new_timestep') == False`` (a further forward within the
      same timestep): the latest slot is overwritten.
    * otherwise (first forward of a timestep, or flag unset): a slot is appended.

    With ``Globals.get('save_aux')`` falsy only the latest slot holds a value and
    earlier slots are ``None`` (the list still grows by one per timestep). With it
    truthy earlier values are kept, and the previous latest value is moved to CPU
    when a new slot is appended.

    Args:
        module: the hooked module; it gets an ``_aux`` dict if it has none.
        args, kwargs, output: forward-hook arguments supplied by PyTorch.
        mode, args_idx, kwargs_key, fn_to_run: select what is recorded; see
            ``_select_stash_value``.
        key: entry of ``module._aux`` to update; other entries are left untouched.
        last: accepted for interface compatibility; not used.

    Returns:
        None, so the hooked module's output is not replaced.
    """
    to_save = _select_stash_value(args, kwargs, output, mode, args_idx, kwargs_key, fn_to_run)
    # `== False` on purpose: an unset flag (None) must take the new-timestep path.
    same_timestep = Globals.get('new_timestep') == False  # noqa: E712
    keep_history = Globals.get('save_aux')

    aux = getattr(module, '_aux', None)
    if aux is None:
        aux = module._aux = {}
    history = aux.get(key)
    if not history:
        # First value for this key (or an empty list): start a one-slot history.
        aux[key] = [to_save]
        return None

    if same_timestep:
        if keep_history:
            history[-1] = to_save
        else:
            aux[key] = [None] * (len(history) - 1) + [to_save]
    elif keep_history:
        if history[-1] is not None:
            # Offload the finished timestep's value before appending the new one.
            history[-1] = history[-1].cpu()
        history.append(to_save)
    else:
        aux[key] = [None] * len(history) + [to_save]
    return None


class SelfGuidanceSDXLPipeline(StableDiffusionXLPipeline):
    """``StableDiffusionXLPipeline`` with self-guidance on stashed UNet values.

    Self-guidance reads the ``_aux`` dicts that forward hooks leave on UNet
    submodules (this class registers no hooks itself). On each guided step the
    edit losses in ``sg_edits`` are evaluated on those values, and the gradient
    of their weighted sum with respect to the latents is added to the noise
    prediction.
    """

    def get_sg_aux(self, cfg=True, transpose=True):
        """Gather the ``_aux`` dicts of all UNet submodules.

        Args:
            cfg: keep only the second half of every stashed batch
                (``chunk(2)[1]``), i.e. the conditional half when the UNet ran
                on ``[negative, positive]`` prompt embeddings. ``None`` entries
                stay ``None``.
            transpose: return ``{aux_key: {module_name: value}}`` if True,
                otherwise ``{module_name: {aux_key: value}}``.

        Returns:
            A ``defaultdict(dict)`` as above; submodules without ``_aux`` are skipped.
        """
        def cond_half(tree):
            return torch.utils._pytree.tree_map(
                lambda vv: vv.chunk(2)[1] if vv is not None else None, tree
            )

        aux = defaultdict(dict)
        for name, aux_module in self.unet.named_modules():
            try:
                module_aux = aux_module._aux
            except AttributeError:
                continue  # nothing was stashed on this submodule
            if transpose:
                for k, v in module_aux.items():
                    aux[k][name] = cond_half(v) if cfg else v
            else:
                aux[name] = {k: cond_half(v) for k, v in module_aux.items()} if cfg else module_aux
        return aux

    def wipe_sg_aux(self):
        """Remove the ``_aux`` attribute from every UNet submodule that has one."""
        for aux_module in self.unet.modules():
            if hasattr(aux_module, '_aux'):
                del aux_module._aux

    def _resolve_sg_edit_idxs(self, prompt, sg_edits):
        """Fill ``edit['idxs']`` in place with prompt token positions for every edit.

        Edits without ``'words'`` get every position of the tokenized (first)
        prompt; otherwise the positions where each word's token sequence occurs
        in the prompt, as located by ``search_sequence_numpy``.
        """
        prompt_text_ids = self.tokenizer(prompt, return_tensors='np')['input_ids'][0]
        for edits in sg_edits.values():
            for edit in edits:
                if 'words' not in edit:
                    edit['idxs'] = np.arange(len(prompt_text_ids))
                    continue
                words = edit['words']
                if not isinstance(words, list):
                    words = [words]
                idxs = []
                for word in words:
                    word_ids = self.tokenizer(word, return_tensors='np')['input_ids']
                    # Ids >= 49406 are presumably CLIP's start/end-of-text tokens;
                    # dropped so only the word's own tokens are searched. Confirm per tokenizer.
                    word_ids = word_ids[word_ids < 49406]
                    idxs.append(search_sequence_numpy(prompt_text_ids, word_ids))
                edit['idxs'] = np.concatenate(idxs)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        sg_grad_wt = 1.0,
        sg_edits = None,
        sg_loss_rescale = 1000.0, #prevent fp16 underflow
        debug=False,
        feats_start=0,
        feats_end=-1,
    ):
        """Generate images with SDXL, optionally steered by self-guidance.

        Arguments up to ``target_size`` follow ``StableDiffusionXLPipeline.__call__``.

        Args:
            sg_grad_wt: multiplier on the self-guidance gradient added to the
                noise prediction. Guidance runs only if > 0 and ``sg_edits`` is given.
            sg_edits: dict mapping an aux key (str) or a tuple of aux keys to a
                list of edit dicts. Each edit needs ``'fn'``, called as
                ``fn(value, i=step, idxs=..., tgt=..., **edit.get('kwargs', {}))``,
                and may set ``'words'``, ``'weight'`` (default 1.0; falsy skips the
                edit) and ``'tgt'``. ``edit['idxs']`` is written in place here.
            sg_loss_rescale: the loss is multiplied by this before
                differentiation and the gradient divided by it afterwards, to
                avoid fp16 underflow.
            debug: print loss and gradient statistics on every guided step.
            feats_start, feats_end: guidance runs for step indices
                ``feats_start <= i < feats_end``; negative ``feats_end`` means all steps.

        Returns:
            ``StableDiffusionXLPipelineOutput``, or ``(image,)`` if ``return_dict`` is False.
        """
        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen paper
        # (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` disables CFG.
        do_classifier_free_guidance = guidance_scale > 1.0
        do_self_guidance = sg_grad_wt > 0 and sg_edits is not None

        if do_self_guidance:
            self._resolve_sg_edit_idxs(prompt, sg_edits)

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = self._get_add_time_ids(
            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
        )

        if do_classifier_free_guidance:
            # Unconditional half first, conditional half second (get_sg_aux relies on this).
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 8. Denoising loop
        self.wipe_sg_aux()
        torch.cuda.empty_cache()
        if feats_end < 0:
            feats_end = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Build an autograd graph only on steps where self-guidance actually runs.
                guide_this_step = do_self_guidance and feats_start <= i < feats_end
                with torch.set_grad_enabled(guide_this_step):
                    # Fresh leaf: the gradient is taken w.r.t. this step's latents, and the
                    # caller-supplied `latents` tensor is never flagged as requiring grad.
                    latents = latents.detach().requires_grad_(guide_this_step)
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # predict the noise residual
                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    ### SELF GUIDANCE
                    if guide_this_step:
                        sg_aux = self.get_sg_aux(do_classifier_free_guidance)
                        sg_loss = 0
                        for edit_key, edits in sg_edits.items():
                            if isinstance(edit_key, str):
                                key_aux = sg_aux[edit_key]
                            else:
                                key_aux = {'': {k: sg_aux[k] for k in edit_key}}
                            for edit in edits:
                                wt = edit.get('weight', 1.)
                                if not wt:
                                    continue
                                tgt = edit.get('tgt')
                                if tgt is not None:
                                    if isinstance(edit_key, str):
                                        tgt = tgt[edit_key]
                                    else:
                                        tgt = {'': {k: tgt[k] for k in edit_key}}
                                edit_loss = torch.stack([
                                    edit['fn'](v, i=i, idxs=edit['idxs'], **edit.get('kwargs', {}),
                                               tgt=tgt[k] if tgt is not None else None)
                                    for k, v in key_aux.items()
                                ]).mean()
                                sg_loss += wt * edit_loss
                        # sg_loss stays the int 0 when every edit is skipped; nothing to differentiate then.
                        if torch.is_tensor(sg_loss):
                            sg_grad = torch.autograd.grad(sg_loss_rescale * sg_loss, latents)[0] / sg_loss_rescale
                            if debug:
                                print(f'Self guidance loss:{sg_loss}. gradient mean magnitude: {sg_grad.abs().mean()}. Pct zeros: {(sg_grad==0).sum()/torch.numel(sg_grad)}')
                            noise_pred = noise_pred + sg_grad_wt * sg_grad
                            assert not noise_pred.isnan().any()
                    ### END SELF GUIDANCE
                latents = latents.detach()

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        torch.cuda.empty_cache()
        # Keep the stashed values only if the caller asked for them via Globals.
        if not Globals.get('save_aux'):
            self.wipe_sg_aux()
        latents = latents.detach()
        # make sure the VAE is in float32 mode, as it overflows in float16
        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        if output_type == "latent":
            return StableDiffusionXLPipelineOutput(images=latents)

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)


class SelfGuidanceSD_v1_5_Pipeline(StableDiffusionPipeline):
    """``StableDiffusionPipeline`` (SD 1.5) with iterative self-guidance on the latents.

    ``register_hook`` installs forward hooks that stash UNet features (and alias
    cross-attention entries) in per-module ``_aux`` dicts. In ``__call__``, on
    guided steps the latents are updated by gradient descent on the edit losses
    before the regular denoising step.
    """

    def get_sg_aux(self, cfg=True, transpose=True, get_whole=False):
        """Gather the ``_aux`` dicts of all UNet submodules.

        Args:
            cfg: keep only the second half of every stashed batch
                (``chunk(2)[1]``), i.e. the text-conditioned half when the UNet
                ran on ``[uncond, text]`` context. ``None`` entries stay ``None``.
            transpose: return ``{aux_key: {module_name: value}}`` if True,
                otherwise ``{module_name: {aux_key: value}}``.
            get_whole: (transpose mode only) requires each history to have
                exactly 50 entries and returns it with their average appended
                as an extra last element. The stored history is not modified.

        Returns:
            A ``defaultdict(dict)`` as above; submodules without ``_aux`` are skipped.
        """
        def cond_half(tree):
            return torch.utils._pytree.tree_map(
                lambda vv: vv.chunk(2)[1] if vv is not None else None, tree
            )

        aux = defaultdict(dict)
        for name, aux_module in self.unet.named_modules():
            try:
                module_aux = aux_module._aux
            except AttributeError:
                continue  # nothing was stashed on this submodule
            if transpose:
                for k, v in module_aux.items():
                    if cfg:
                        v = cond_half(v)
                    if get_whole:
                        # NOTE(review): 50 is hard-coded (presumably the default step count).
                        assert len(v) == 50
                        # Build a new list so the module's stored history is not mutated.
                        v = v + [sum(vv.cpu() for vv in v) / len(v)]
                    aux[k][name] = v
            else:
                aux[name] = {k: cond_half(v) for k, v in module_aux.items()} if cfg else module_aux
        return aux

    def wipe_sg_aux(self):
        """Remove the ``_aux`` attribute from every UNet submodule that has one."""
        for aux_module in self.unet.modules():
            if hasattr(aux_module, '_aux'):
                del aux_module._aux

    def save_layout_attn(self, attn, step, words):
        """Write ``attn`` as a 256x256 PNG into ``Globals.get('attn_path')``.

        The map is scaled so its maximum becomes 255, replicated into three
        channels and saved as ``step_{step}_{words}_total.png``. Assumes ``attn``
        is a 2-D tensor (TODO confirm); a maximum of 0 would produce NaNs.
        """
        image = 255 * attn / attn.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3).cpu().detach()
        image = image.numpy().astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image.save(os.path.join(Globals.get('attn_path'), f"step_{step}_{words}_total.png"))

    def register_hook(self):
        """Register the feature-stashing forward hooks on the UNet.

        For each listed block and each of its attention modules ``a``:
        ``feats_{down|up}_{b}_{a}`` stashes the attention module's output,
        ``attn_{down|up}_{b}_{a}_0`` aliases its first transformer block's
        ``attn2._aux['attn']``, and ``feats_{down|up}_{b}`` stashes the whole
        block's output. Calling this twice registers duplicate hooks (handles
        are not kept).
        """
        # (prefix, blocks, block index, number of attention modules in that block)
        hooked_blocks = (
            ('down', self.unet.down_blocks, 1, 2),
            ('down', self.unet.down_blocks, 2, 2),
            ('up', self.unet.up_blocks, 1, 3),
            ('up', self.unet.up_blocks, 2, 3),
            ('up', self.unet.up_blocks, 3, 3),
        )
        # These stash hooks are additionally given last=True.
        last_keys = {'feats_up_2_2', 'feats_up_3_2'}
        for prefix, blocks, b, n_attn in hooked_blocks:
            block = blocks[b]
            for a in range(n_attn):
                attention = block.attentions[a]
                feats_key = f'feats_{prefix}_{b}_{a}'
                extra = {'last': True} if feats_key in last_keys else {}
                attention.register_forward_hook(partial(stash_to_aux, mode='output', key=feats_key, **extra), with_kwargs=True)
                attention.transformer_blocks[0].attn2.register_forward_hook(partial(resave_aux_key, new_key=f'attn_{prefix}_{b}_{a}_0'))
            block.register_forward_hook(partial(stash_to_aux, mode='output', key=f'feats_{prefix}_{b}'), with_kwargs=True)

    def _resolve_sg_edit_idxs(self, prompt, sg_edits):
        """Fill ``edit['idxs']`` in place with prompt token positions for every edit.

        Edits without ``'words'`` get every position of the tokenized (first)
        prompt; otherwise the positions where each word's token sequence occurs
        in the prompt, as located by ``search_sequence_numpy``.
        """
        prompt_text_ids = self.tokenizer(prompt, return_tensors='np')['input_ids'][0]
        for edits in sg_edits.values():
            for edit in edits:
                if 'words' not in edit:
                    edit['idxs'] = np.arange(len(prompt_text_ids))
                    continue
                words = edit['words']
                if not isinstance(words, list):
                    words = [words]
                idxs = []
                for word in words:
                    word_ids = self.tokenizer(word, return_tensors='np')['input_ids']
                    # Ids >= 49406 are presumably CLIP's start/end-of-text tokens;
                    # dropped so only the word's own tokens are searched. Confirm per tokenizer.
                    word_ids = word_ids[word_ids < 49406]
                    idxs.append(search_sequence_numpy(prompt_text_ids, word_ids))
                edit['idxs'] = np.concatenate(idxs)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        uncond_embeddings=None,
        sg_grad_wt = 1.0,
        sg_edits = None,
        sg_loss_rescale = 1000.0, #prevent fp16 underflow
        debug=False,
        feats_start=0,
        feats_end=-1,
    ):
        """Generate images with SD 1.5, optionally optimizing latents by self-guidance.

        Arguments up to ``guidance_rescale`` follow ``StableDiffusionPipeline.__call__``.

        Args:
            uncond_embeddings: optional per-step unconditional embeddings;
                ``uncond_embeddings[i]`` is expanded to the text-embedding shape
                at step ``i``. When None, the empty-prompt embedding is used.
            sg_grad_wt: step size of the latent update ``latents -= sg_grad_wt * grad``.
                Guidance runs only if > 0 and ``sg_edits`` is given.
            sg_edits: dict mapping an aux key (str) or a tuple of aux keys to a
                list of edit dicts. Each edit with a truthy ``'weight'`` (default
                1.0) needs ``'fn'`` and ``'timestep': (start, end)``; it is active
                for steps ``start <= i < end``. Optional keys: ``'words'``,
                ``'tgt'``, ``'kwargs'``. ``edit['idxs']`` is written in place here.
            sg_loss_rescale: loss multiplier before differentiation (the gradient
                is divided by it afterwards) to avoid fp16 underflow.
            debug: print loss and gradient statistics on every optimization iteration.
            feats_start, feats_end: accepted for signature parity with the SDXL
                pipeline; not used here. The guided step range comes from
                ``global_start``/``global_end`` and the iteration cap from
                ``num_iteration``, which are not defined in this file (presumably
                provided by ``from globals import *`` — confirm).

        Returns:
            ``StableDiffusionPipelineOutput``, or ``(image, has_nsfw_concept)``
            if ``return_dict`` is False.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # `guidance_scale` is the guidance weight `w` of eq. (2) in the Imagen paper
        # (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` disables CFG.
        do_classifier_free_guidance = guidance_scale > 1.0
        do_self_guidance = sg_grad_wt > 0 and sg_edits is not None

        if do_self_guidance:
            self._resolve_sg_edit_idxs(prompt, sg_edits)

        # 3. Encode input prompt. Below, these embeddings are only used for their dtype
        # and the safety checker; the UNet is conditioned on `context` instead.
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        self.gaussian_smoothing = GaussianSmoothing().to(device)

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
        max_length = text_input.input_ids.shape[-1]
        if uncond_embeddings is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings_ = self.text_encoder(uncond_input.input_ids.to(device))[0]
        else:
            uncond_embeddings_ = None

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self.wipe_sg_aux()
        torch.cuda.empty_cache()
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Step state for code outside this file (e.g. stash_to_aux reads 'new_timestep').
                Globals.set('step', i)
                Globals.set('new_timestep', True)
                # Unconditional half first, text half second (get_sg_aux relies on this).
                # NOTE(review): context always has both halves, so guidance_scale <= 1
                # appears to give a batch mismatch with `latents` — confirm.
                if uncond_embeddings_ is None:
                    context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
                else:
                    context = torch.cat([uncond_embeddings_, text_embeddings])

                if do_self_guidance and global_start <= i < global_end:
                    iteration = 0
                    sg_loss = float('inf')
                    with torch.set_grad_enabled(True):
                        # Gradient descent on the latents until the loss drops to 0.1 or
                        # the iteration cap is reached.
                        while sg_loss > 0.1 and iteration < num_iteration:
                            Globals.set('NO_OPT', False)
                            # Fresh leaf per iteration: the gradient is w.r.t. the current latents
                            # only and previous iterations' graphs are released.
                            latents = latents.detach().requires_grad_(True)
                            # expand the latents if we are doing classifier free guidance
                            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                            # Run the UNet for its forward hooks, which stash the features the
                            # edit losses read; its noise prediction is not needed here.
                            self.unet(
                                latent_model_input,
                                t,
                                encoder_hidden_states=context,
                                cross_attention_kwargs=cross_attention_kwargs,
                                return_dict=False,
                            )
                            Globals.set('new_timestep', False)
                            iteration += 1

                            sg_aux = self.get_sg_aux(do_classifier_free_guidance)
                            app_loss = 0
                            shape_loss = 0
                            for edit_key, edits in sg_edits.items():
                                if isinstance(edit_key, str):
                                    key_aux = sg_aux[edit_key]
                                else:
                                    key_aux = {'': {k: sg_aux[k] for k in edit_key}}
                                for edit in edits:
                                    wt = edit.get('weight', 1.)
                                    if not wt:
                                        continue
                                    tgt = edit.get('tgt')
                                    start, end = edit.get('timestep')
                                    if not start <= i < end:
                                        continue  # edit inactive at this step
                                    if tgt is not None:
                                        if isinstance(edit_key, str):
                                            tgt = tgt[edit_key]
                                        else:
                                            tgt = {'': {k: tgt[k] for k in edit_key}}
                                    edit_loss = torch.stack([
                                        edit['fn'](v, i=i, idxs=edit['idxs'], **edit.get('kwargs', {}),
                                                   tgt=tgt[k] if tgt is not None else None)
                                        for k, v in key_aux.items()
                                    ]).mean()
                                    # Split by whether the loss fn's str() mentions 'shape'; both
                                    # parts are summed, the split only feeds the debug print.
                                    if 'shape' in str(edit['fn']):
                                        shape_loss += wt * edit_loss
                                    else:
                                        app_loss += wt * edit_loss

                            loss = shape_loss + app_loss
                            if not torch.is_tensor(loss):
                                # No edit is active at this step: nothing to optimize.
                                break
                            grad = torch.autograd.grad((sg_loss_rescale * loss), latents)[0] / sg_loss_rescale

                            if debug: print(f'total loss:{loss:.4f}, shape loss:{shape_loss:.4f}, app loss:{app_loss:.4f}, gradient mean magnitude: {grad.abs().mean():.4f}. min: {grad.abs().min():.4f}. max: {grad.abs().max():.4f}')
                            if grad.isnan().any():
                                print(f'Self guidance: NaN gradient at step {i}, iteration {iteration}')
                            latents = latents - sg_grad_wt * grad
                            # Loss at the pre-update latents drives the early-stop check.
                            sg_loss = loss.item()
                    latents = latents.detach()

                Globals.set('NO_OPT', True)
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=context,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        torch.cuda.empty_cache()
        # Keep the stashed values only if the caller asked for them via Globals.
        if not Globals.get('save_aux'):
            self.wipe_sg_aux()
        latents = latents.detach()
        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
